{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q7qTZbUcg5VD"
      },
      "source": [
        "<h1 align=\"center\">GRPO Finetuning</h1>\n",
        "\n",
        "---\n",
        "<h1 align=\"center\">Training a small math reasoner with RL</h1>\n",
        "<a href=\"https://colab.research.google.com/github/adithya-s-k/AI-Engineering.academy/blob/feat/theory-behind-finetuning/docs/LLM/HandsOnWithFinetuning/GRPO/GRPO_finetuning_notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gV4W0sp1UWKe"
      },
      "source": [
        "This notebook is an alternate version of the [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) by [Will Brown](https://x.com/willccbb), which trains Llama-1B on the GSM8K math dataset.\n",
        "\n",
        "We have made only a few changes so that the code runs more easily on Colab:\n",
        "* Replaced Llama-1B with Qwen2.5-0.5B-Instruct.\n",
        "* Switched generation to vLLM, which yields a significant speed-up. Qwen's small size makes it possible to run vLLM on the same GPU that is used for GRPO.\n",
        "* Dropped flash-attn (a recurring bug with the Qwen modeling code; the cause is unclear)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GPYBrSbY79we"
      },
      "source": [
        "## Setting up the models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GOMhew_59RbM"
      },
      "source": [
        "First we install vLLM. Note that you will have to restart the Colab session afterwards."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "PYykgnUJ0BdB",
        "outputId": "a89e77ea-f4eb-499d-be6f-05972ac0e8a5"
      },
      "outputs": [],
      "source": [
        "!pip install vllm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OJxgfykz93lG"
      },
      "source": [
        "Then we install TRL and datasets. The order matters: installing vLLM after TRL triggers a bug in TRL."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ybtxR89X1YJq",
        "outputId": "0cd9cc87-a2d0-46f7-b24d-f4d5ead6c888"
      },
      "outputs": [],
      "source": [
        "!pip install trl datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJNTq5HG-EYI"
      },
      "source": [
        "## Defining the RL rewards"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pbej_WBE6wLV"
      },
      "source": [
        "With everything installed, we can set up the RL training set and the reward functions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bwqrjX1_-J3s"
      },
      "source": [
        "First we set the general prompt structure (with the reasoning tags)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "q04kVVaQ6dSe",
        "outputId": "532f8722-ea3c-49c9-eec9-82a3af148248"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import torch\n",
        "from datasets import load_dataset, Dataset\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from trl import GRPOConfig, GRPOTrainer\n",
        "\n",
        "# Prompt templates\n",
        "\n",
        "SYSTEM_PROMPT = \"\"\"\n",
        "Respond in the following format:\n",
        "<reasoning>\n",
        "...\n",
        "</reasoning>\n",
        "<answer>\n",
        "...\n",
        "</answer>\n",
        "\"\"\"\n",
        "\n",
        "XML_COT_FORMAT = \"\"\"\\\n",
        "<reasoning>\n",
        "{reasoning}\n",
        "</reasoning>\n",
        "<answer>\n",
        "{answer}\n",
        "</answer>\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "38AVgA19-PMk"
      },
      "source": [
        "Now we load the GSM8K dataset and restructure it into a conversational prompt format:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fno7X8Fh-N6k"
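      },
      "outputs": [],
      "source": [
        "def extract_xml_answer(text: str) -> str:\n",
        "    # Keep only the text between the last <answer> tag and its closing tag\n",
        "    answer = text.split(\"<answer>\")[-1]\n",
        "    answer = answer.split(\"</answer>\")[0]\n",
        "    return answer.strip()\n",
        "\n",
        "def extract_hash_answer(text: str) -> str | None:\n",
        "    # GSM8K solutions end with '#### <final answer>'\n",
        "    if \"####\" not in text:\n",
        "        return None\n",
        "    return text.split(\"####\")[1].strip()\n",
        "\n",
        "# Builds zero-shot prompts (system + user). For 1-shot prompting, insert an example\n",
        "# user/assistant pair between the two messages (XML_COT_FORMAT can format the reply).\n",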
        "def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
        "    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
        "    data = data.map(lambda x: { # type: ignore\n",
        "        'prompt': [\n",
        "            {'role': 'system', 'content': SYSTEM_PROMPT},\n",
        "            {'role': 'user', 'content': x['question']}\n",
        "        ],\n",
        "        'answer': extract_hash_answer(x['answer'])\n",
        "    }) # type: ignore\n",
        "    return data # type: ignore\n",
        "\n",
        "dataset = get_gsm8k_questions()"
      ]
    },
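    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For 1-shot prompting, one solved example is placed between the system message and the real question. The sketch below shows one way to build such a prompt. `build_one_shot_prompt` and the example question are hypothetical and not part of the original demo; `XML_COT_FORMAT` is re-declared so the snippet runs on its own.\n",
        "\n",
        "```python\n",
        "# Same template as XML_COT_FORMAT above, re-declared to keep the sketch standalone\n",
        "XML_COT_FORMAT = '<reasoning>\\n{reasoning}\\n</reasoning>\\n<answer>\\n{answer}\\n</answer>\\n'\n",
        "\n",
        "def build_one_shot_prompt(system_prompt: str, question: str) -> list[dict]:\n",
        "    # Hypothetical helper: returns chat messages with one solved example first\n",
        "    example_question = 'Tom has 3 apples and buys 4 more. How many apples does he have?'\n",
        "    example_reply = XML_COT_FORMAT.format(reasoning='3 + 4 = 7.', answer='7')\n",
        "    return [\n",
        "        {'role': 'system', 'content': system_prompt},\n",
        "        {'role': 'user', 'content': example_question},\n",
        "        {'role': 'assistant', 'content': example_reply},\n",
        "        {'role': 'user', 'content': question},\n",
        "    ]\n",
        "\n",
        "messages = build_one_shot_prompt('Respond in the XML format.', 'What is 12 * 3?')\n",
        "for m in messages:\n",
        "    print(m['role'], '->', repr(m['content']))\n",
        "```\n",
        "\n",
        "The extra example makes every prompt longer, so `max_prompt_length` may need to grow if you try this."
      ]
    },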
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yi-7Hs0T-YwB"
      },
      "source": [
        "We now move on to the reward functions. The most important one is the \"correctness\" function, which acts as a verifier: it compares the answer extracted from each completion with the reference answer. The other four shape the output: one rewards integer answers and three reward adherence to the XML format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BLCIyOzI0Gol"
      },
      "outputs": [],
      "source": [
        "# Reward functions\n",
        "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
        "    \"\"\"Verifier: 2.0 if the extracted answer equals the reference answer, else 0.0.\"\"\"\n",
        "    responses = [completion[0]['content'] for completion in completions]\n",
        "    q = prompts[0][-1]['content']\n",
        "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
        "    # Log the first sample of the batch to monitor training\n",
        "    print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
        "    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
        "\n",
        "def int_reward_func(completions, **kwargs) -> list[float]:\n",
        "    \"\"\"Reward answers made up of digits only.\"\"\"\n",
        "    responses = [completion[0]['content'] for completion in completions]\n",
        "    extracted_responses = [extract_xml_answer(r) for r in responses]\n",
        "    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
        "\n",
        "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
        "    \"\"\"Reward completions that follow the exact tag-and-newline layout requested in the system prompt.\"\"\"\n",
        "    pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
        "    responses = [completion[0][\"content\"] for completion in completions]\n",
        "    # re.DOTALL lets `.` span newlines, so multi-line reasoning can match\n",
        "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n",
        "    return [0.5 if match else 0.0 for match in matches]\n",
        "\n",
        "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
        "    \"\"\"Reward completions that start with a reasoning block followed by an answer block, with flexible whitespace.\"\"\"\n",
        "    pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
        "    responses = [completion[0][\"content\"] for completion in completions]\n",
        "    # Without re.DOTALL the newline after <reasoning> would prevent any match\n",
        "    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]\n",
        "    return [0.5 if match else 0.0 for match in matches]\n",
        "\n",
        "def count_xml(text) -> float:\n",
        "    count = 0.0\n",
        "    if text.count(\"<reasoning>\\n\") == 1:\n",
        "        count += 0.125\n",
        "    if text.count(\"\\n</reasoning>\\n\") == 1:\n",
        "        count += 0.125\n",
        "    if text.count(\"\\n<answer>\\n\") == 1:\n",
        "        count += 0.125\n",
        "        count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
        "    if text.count(\"\\n</answer>\") == 1:\n",
        "        count += 0.125\n",
        "        count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
        "    return count\n",
        "\n",
        "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
        "    contents = [completion[0][\"content\"] for completion in completions]\n",
        "    return [count_xml(c) for c in contents]"
      ]
    },
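    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how the individual rewards add up, the sketch below scores a single hand-written completion. It is a standalone illustration under stated assumptions: the completion and the gold answer are made up, `extract_xml_answer` is re-declared so the snippet runs on its own, and the soft pattern is matched with `re.DOTALL` so that the newlines inside the tags are accepted.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "def extract_xml_answer(text: str) -> str:\n",
        "    # Same logic as the notebook helper: keep what sits between the answer tags\n",
        "    return text.split('<answer>')[-1].split('</answer>')[0].strip()\n",
        "\n",
        "# Hypothetical completion, in the list-of-messages shape the reward functions index into\n",
        "completion = [{'role': 'assistant',\n",
        "               'content': '<reasoning>\\n6 * 7 = 42\\n</reasoning>\\n<answer>\\n42\\n</answer>\\n'}]\n",
        "gold = '42'  # made-up reference answer\n",
        "\n",
        "text = completion[0]['content']\n",
        "extracted = extract_xml_answer(text)\n",
        "soft_pattern = r'<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>'\n",
        "\n",
        "rewards = {\n",
        "    'correctness': 2.0 if extracted == gold else 0.0,\n",
        "    'integer': 0.5 if extracted.isdigit() else 0.0,\n",
        "    'soft_format': 0.5 if re.match(soft_pattern, text, flags=re.DOTALL) else 0.0,\n",
        "}\n",
        "total = sum(rewards.values())\n",
        "print(rewards, total)\n",
        "```\n",
        "\n",
        "The correctness term dominates the sum, while the formatting terms give partial credit to completions that are well structured but wrong."
      ]
    },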
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cuz-LQOQ-vSN"
      },
      "source": [
        "We now set the training arguments and load the model and tokenizer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gGFFu5u4-3uV"
      },
      "outputs": [],
      "source": [
        "model_name = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
        "\n",
        "output_dir = \"outputs/Qwen-0.5B-GRPO\"\n",
        "run_name = \"Qwen-0.5B-GRPO-gsm8k\"\n",
        "\n",
        "training_args = GRPOConfig(\n",
        "    output_dir=output_dir,\n",
        "    run_name=run_name,\n",
        "    learning_rate=5e-6,\n",
        "    adam_beta1=0.9,\n",
        "    adam_beta2=0.99,\n",
        "    weight_decay=0.1,\n",
        "    warmup_ratio=0.1,\n",
        "    lr_scheduler_type='cosine',\n",
        "    logging_steps=1,\n",
        "    bf16=True,\n",
        "    per_device_train_batch_size=1,\n",
        "    gradient_accumulation_steps=4,\n",
        "    num_generations=16,  # completions sampled per prompt\n",
        "    max_prompt_length=256,\n",
        "    max_completion_length=200,\n",
        "    num_train_epochs=1,\n",
        "    save_steps=100,\n",
        "    max_grad_norm=0.1,\n",
        "    log_on_each_node=False,\n",
        "    use_vllm=True,\n",
        "    vllm_gpu_memory_utilization=0.3,  # fraction of GPU memory reserved for vLLM\n",
        "    vllm_device=\"cuda:0\",  # vLLM shares the training GPU\n",
        "    report_to=\"none\"  # disable Weights & Biases logging\n",
        ")\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    torch_dtype=torch.bfloat16,\n",
        "    device_map=None\n",
        ").to(\"cuda\")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "tokenizer.pad_token = tokenizer.eos_token"
      ]
    },
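    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `num_generations=16`, each prompt gets 16 sampled completions. GRPO then scores every completion relative to the others in its group: the advantage is the reward minus the group mean, divided by the group standard deviation. The sketch below illustrates that normalization in plain Python. The rewards, the group size of 4 and the `eps` value are made up for the example, and `group_relative_advantages` is a hypothetical helper, not TRL's implementation.\n",
        "\n",
        "```python\n",
        "from statistics import mean, stdev\n",
        "\n",
        "def group_relative_advantages(rewards: list[float], num_generations: int, eps: float = 1e-4) -> list[float]:\n",
        "    # Normalize each reward against the other completions sampled for the same prompt\n",
        "    advantages = []\n",
        "    for start in range(0, len(rewards), num_generations):\n",
        "        group = rewards[start:start + num_generations]\n",
        "        mu, sigma = mean(group), stdev(group)\n",
        "        # eps (assumed value) avoids division by zero when all rewards are equal\n",
        "        advantages.extend((r - mu) / (sigma + eps) for r in group)\n",
        "    return advantages\n",
        "\n",
        "# Hypothetical summed rewards for 2 prompts with 4 completions each\n",
        "rewards = [3.0, 0.5, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0]\n",
        "adv = group_relative_advantages(rewards, num_generations=4)\n",
        "print([round(a, 2) for a in adv])\n",
        "```\n",
        "\n",
        "A group whose completions all earn the same reward gets zero advantage everywhere and therefore contributes no learning signal, which is one reason graded rewards such as `xmlcount_reward_func` can help early in training."
      ]
    },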
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "REuVM0ep-4dd"
      },
      "source": [
        "And launch the actual training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "237d44c6ba514d628d8b106c31113f51",
            "fcf49a8b18c34c83af7a4b2231d34ac9",
            "a6e9dc143bcc46a3a142f47312021c2b",
            "fc95f56a843c47e1a189b438ffd2f54d",
            "d878defc0eeb4afa84d47389262b2dd4",
            "68396d95b92f4e42b2689522a11ac3bf",
            "f910e3eb2c7e40a69b439facbcf6e3a2",
            "970af5d3c35342c0b8792230254f7b08",
            "2576dab321ea428bb4f9da69272eb97c",
            "bc184dbe11ae436db0fbdee8659626c3",
            "e41fb61a3c0c42ea8251a4711833be5e"
          ]
        },
        "id": "U1ixGbPG0Ni-",
        "outputId": "ae4d7b91-33dd-40ed-ce62-ac5c1337373c"
      },
      "outputs": [],
      "source": [
        "# use peft at your own risk; not working for me with multi-GPU training\n",
        "trainer = GRPOTrainer(\n",
        "    model=model,\n",
        "    processing_class=tokenizer,\n",
        "    reward_funcs=[\n",
        "        xmlcount_reward_func,\n",
        "        soft_format_reward_func,\n",
        "        strict_format_reward_func,\n",
        "        int_reward_func,\n",
        "        correctness_reward_func],\n",
        "    args=training_args,\n",
        "    train_dataset=dataset,\n",
        "    #peft_config=peft_config\n",
        ")\n",
        "trainer.train()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
